import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from patsy import dmatrix
from statsmodels.gam.api import GLMGam, BSplines
from scipy.linalg import solve
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from scipy.linalg import solve
import torch
import torch.nn as nn
from scipy.stats import multivariate_t
from scipy.stats import multivariate_normal
from sklearn.ensemble import RandomForestRegressor
import argparse
from scipy.optimize import minimize###
def parse_args():
    """Parse the command-line arguments of this script.

    The previous body returned ``parser.parse_args()`` without ever creating
    ``parser``, so every call raised ``NameError``. A parser is now built
    here; no options are registered because none are visible in this file.

    Returns:
        argparse.Namespace: the parsed arguments (empty until options are added).
    """
    parser = argparse.ArgumentParser()
    # TODO(review): register the script's options here once they are known.
    return parser.parse_args()


# Generate next stats
def gen_nextStats(data, nS):
    """Append ``next.S1`` .. ``next.S{nS}`` columns holding the next row's state.

    Each ``next.S*`` value is the ``S*`` value of the following row. Rows
    whose ``time`` equals the number of distinct time points, and the final
    row, get zeros.

    Unlike the previous version, this works for any index on ``data``: the
    "last time step" rows are located by position rather than by index label,
    and the new columns are aligned to ``data.index`` before concatenation.

    Args:
        data: DataFrame with a ``time`` column and state columns ``S1`` .. ``S{nS}``.
        nS: number of state columns.

    Returns:
        A new DataFrame: ``data`` with the ``next.S*`` columns appended.
    """
    state_cols = [f"S{i+1}" for i in range(nS)]
    S = data[state_cols].to_numpy()

    # Shift the states up by one row; the last row has no successor.
    NextS = np.zeros(S.shape, dtype=float)
    NextS[:-1] = S[1:]

    # Rows at the final time step must not borrow the following row's state.
    num_time = len(data['time'].unique())
    is_last_step = (data['time'] == num_time).to_numpy()
    NextS[is_last_step, :] = 0

    next_S_df = pd.DataFrame(
        NextS,
        columns=[f"next.S{i+1}" for i in range(nS)],
        index=data.index,  # keep axis=1 concat row-aligned for any index
    )
    return pd.concat([data, next_S_df], axis=1)

# Create phi basis##
def phi_basis(X):
    """Expand every column of ``X`` into its first three powers.

    For a 2-D array with columns ``x_1 .. x_p`` the result has ``3 * p``
    columns ordered ``x_1, x_1**2, x_1**3, x_2, x_2**2, x_2**3, ...``.
    """
    per_feature = []
    for j in range(X.shape[1]):
        column = X[:, j]
        per_feature.append(np.column_stack([column, column**2, column**3]))
    return np.hstack(per_feature)

# Behavior estimation
def behavior_est(data, treatment, method='All'):
    """Estimate the behaviour-policy probability of each observed action.

    Adds a ``pi_t`` column: the fitted probability of the action actually
    taken, from a logistic regression of ``action`` on the state columns
    ``S1`` .. ``S{nS}``.

    Changes from the previous version:
      * A single-valued ``action`` column (all zeros *or* all ones) now gets
        ``pi_t = 1``. Before, only all zeros was handled and all ones made
        the logistic fit fail.
      * The per-time branch re-sorts with a stable sort, so rows of one user
        stay in time order.

    Args:
        data: DataFrame with ``user``, ``time``, ``action`` and state columns.
            With ``method='All'`` it is modified in place.
        treatment: unused; kept for interface compatibility.
        method: ``'All'`` fits one model on all rows; any other value fits a
            separate model per ``time`` value.

    Returns:
        DataFrame with the ``pi_t`` column.
    """
    y = data['action']

    # Only one action observed: its empirical probability is 1 and a
    # logistic regression cannot be fitted.
    if y.nunique(dropna=False) < 2:
        data['pi_t'] = np.ones(len(data)) * 1.0
        return data

    if method != 'All':
        # Fit per time step on copies so group slices are not written through.
        per_time = [
            behavior_est(subdata.copy(), treatment, method='All')
            for _, subdata in data.groupby('time')
        ]
        # Stable sort keeps the ascending-time order within each user.
        return pd.concat(per_time).sort_values('user', kind='stable')

    # NOTE: any column whose name contains 'S' is counted, assuming exactly
    # the S* / next.S* pairs are present.
    nS = len([col for col in data.columns if 'S' in col]) // 2
    covariate_names = [f"S{i+1}" for i in range(nS)]
    X = dmatrix(" + ".join(covariate_names), data, return_type='dataframe')

    model = LogisticRegression().fit(X, y)
    pre_prob = model.predict_proba(X)[:, 1]
    # Probability of the observed action: p if action == 1, else 1 - p.
    data['pi_t'] = data['action'] * pre_prob + (1 - data['action']) * (1 - pre_prob)
    return data


def mu_t_est(data, treatment):
    """Estimate per-row weights ``mu_t = eta_t * omega_t`` using a spline GAM.

    Reads the ``user``, ``time``, ``eta_t`` and ``S1`` .. ``S{nS}`` columns of
    ``data`` and writes ``omega_t``, ``lambda_t`` and ``mu_t`` back into it
    (``data`` is modified in place). ``treatment`` is not used.

    Returns:
        dict with ``'mu_t_est'`` (the normalised weights), ``'omega_t_est'``
        (the ``omega_t`` column) and ``'phi_S_eta_mu_all'`` (always ``None``).
    """
    # Any column name containing 'S' is counted; assumes S*/next.S* pairs only.
    nS = len([col for col in data.columns if 'S' in col]) // 2
    time_num = len(data['time'].unique())  # not used below
    n_num = len(data['user'].unique())

    init_position = data.index[data['time'] == 1]

    # omega_t is 1 at the first time step and 0 elsewhere until estimated.
    data['omega_t'] = 0
    data.loc[init_position, 'omega_t'] = 1

    # Running product of eta_t within each user.
    lambda_t = data.groupby('user')['eta_t'].cumprod()
    data['lambda_t'] = lambda_t

    # Shift the running product down one row (0 prepended, last value dropped)
    # and keep the entries for rows with time != 1.
    lambda_t_before = np.hstack([0, lambda_t[:-1]])
    lambda_t_response = lambda_t_before[data['time'] != 1]

    S = data[[f"S{i+1}" for i in range(nS)]].values
    covariate_names = [f"S{i+1}" for i in range(nS)]

    spline_terms = BSplines(S, df=[6]*nS, degree=[3]*nS)
    # NOTE(review): the formula names `lambda_t_response`, a local array and
    # not a column of `data` -- presumably resolved from the enclosing scope
    # by the formula machinery; it only has the time != 1 rows while `data`
    # and `S` have all rows. Also `s(...)` is not defined in this file.
    # Confirm this fit runs as intended.
    formula = f"lambda_t_response ~ {' + '.join([f's({col})' for col in covariate_names])}"

    gam_model = GLMGam.from_formula(formula, data=data, smoother=spline_terms)
    fitted_model = gam_model.fit()

    # Clip negative fitted values to zero and store them on the time != 1 rows.
    omega_t_est = fitted_model.fittedvalues
    omega_t_est[omega_t_est < 0] = 0
    data.loc[data['time'] != 1, 'omega_t'] = omega_t_est

    # Local name shadows this function's own name inside the body.
    mu_t_est = data['eta_t'] * data['omega_t']
    mu_t_est[mu_t_est < 0] = 0

    # NOTE(review): the row means have len(data) / n_num entries while
    # mu_t_est has len(data) entries, so this division only lines up in
    # special cases (e.g. a single time step) -- confirm the intended
    # normalisation.
    mu_t_matrix = mu_t_est.values.reshape(-1, n_num)
    mu_t_mean = mu_t_est / np.mean(mu_t_matrix, axis=1)

    data['mu_t'] = mu_t_mean

    return {'mu_t_est': mu_t_mean, 'omega_t_est': data['omega_t'], 'phi_S_eta_mu_all': None}

# Backward estimation function
def mu_t_est_backward(data, treatment):
    """Estimate ``mu_t`` weights by least squares on the cubic basis.

    Reads the ``user``, ``time``, ``action``, ``eta_t``, ``S*`` and
    ``next.S*`` columns of ``data``. For each ``k`` it regresses
    ``eta_t * mu_{previous row}`` on ``phi_basis(S)`` over rows with
    ``action == treatment`` and ``time >= k``.

    Returns:
        dict with ``'mu_t_est'`` (array of weights, clipped at 0 and
        normalised) and ``'phi_S_eta_mu_all'`` (the summed moment vectors,
        one column per ``k`` when there is more than one time step).
    """
    # Any column name containing 'S' is counted; assumes S*/next.S* pairs only.
    nS = len([col for col in data.columns if 'S' in col]) // 2
    time_num = len(data['time'].unique())
    n_time = data.shape[0]

    mu_t_est = np.zeros(n_time)
    S = data[[f"S{i+1}" for i in range(nS)]].values
    next_S = data[[f"next.S{i+1}" for i in range(nS)]].values

    phi_S = phi_basis(S)
    phi_next_S = phi_basis(next_S)  # computed but not used below

    # Basis rows weighted by eta_t.
    phi_S_eta = phi_S * data['eta_t'].values[:, None]

    # NOTE(review): index *labels* are used as positions into NumPy arrays
    # here and below; that assumes `data` has a default 0..n-1 index.
    action_pos = data.index[data['action'] == treatment]
    # NOTE(review): the right-hand side has one entry per row of `data`, the
    # left-hand side only selects rows with action == treatment; the lengths
    # agree only if every row took `treatment` -- confirm intent.
    mu_t_est[action_pos] = np.dot(phi_S, solve(np.dot(phi_S.T, phi_S), phi_S_eta.sum(axis=0)))
    # Moment vector for the first time step; further columns are appended
    # in the loop.
    phi_S_eta_mu_all = phi_S_eta.sum(axis=0)

    for k in range(2, time_num+1):
        time_pos = data.index[data['time'] >= k]
        temp_pos = np.intersect1d(action_pos, time_pos)

        # eta_t-weighted basis times the previous row's current mu estimate.
        phi_S_eta_mu = (phi_S[temp_pos] * data['eta_t'].iloc[temp_pos].values[:, None] *
                        mu_t_est[temp_pos-1][:, None])

        gamma_k = solve(np.dot(phi_S[temp_pos].T, phi_S[temp_pos]), phi_S_eta_mu.sum(axis=0))
        mu_k_est = np.dot(phi_S, gamma_k)

        mu_t_est[temp_pos] = mu_k_est[temp_pos]

        phi_S_eta_mu_all = np.column_stack([phi_S_eta_mu_all, phi_S_eta_mu.sum(axis=0)])

    # Clip negatives, then normalise.
    mu_t_est[mu_t_est < 0] = 0
    # NOTE(review): the row means have n_time / n_users entries while
    # mu_t_est has n_time entries, so this division only broadcasts in
    # special cases -- confirm the intended normalisation.
    mu_t_matrix = mu_t_est.reshape(-1, len(data['user'].unique()))
    mu_t_est = mu_t_est / np.mean(mu_t_matrix, axis=1)

    return {'mu_t_est': mu_t_est, 'phi_S_eta_mu_all': phi_S_eta_mu_all}


# omega_t_est_backward function
def omega_t_est_backward(data0, phi_S_eta_mu_all, n_D, n_H, Q_indicator='Linear'):
    """Estimate density-ratio weights ``omega_t`` for every row of ``data0``.

    For each time step ``k`` a coefficient vector is obtained by solving
    ``(Phi_k^T Phi_k) gamma_k = phi_S_eta_mu_all[:, k-1]`` over the rows with
    ``time >= k``, scaled by ``n_H / n_D``. Rows at time ``k`` receive
    ``phi(S) @ gamma_k``. Negative weights are clipped to zero and the
    weights are then divided by their mean within each time step.

    The previous normalisation divided the length-``n_time`` weight vector by
    a length-``time_num`` vector of means, which raised a broadcasting error
    for more than one user. The mean is now computed per ``time`` value and
    matched row by row, which also removes any dependence on row order.

    Args:
        data0: DataFrame with ``user``, ``time``, ``S*`` and ``next.S*`` columns.
        phi_S_eta_mu_all: moment vectors, one column per time step (a 1-D
            vector is accepted for a single time step).
        n_D, n_H: sample sizes used for the ``n_H / n_D`` scaling.
        Q_indicator: ``'Linear'`` uses the raw states as features; any other
            value uses ``phi_basis``.

    Returns:
        1-D NumPy array of normalised weights, one per row of ``data0``.
    """
    # Any column name containing 'S' is counted; assumes S*/next.S* pairs only.
    nS = len([col for col in data0.columns if 'S' in col]) // 2
    S = data0[[f"S{i+1}" for i in range(nS)]].to_numpy()
    phi_S = S if Q_indicator == 'Linear' else phi_basis(S)

    time_values = data0['time'].to_numpy()
    time_num = len(data0['time'].unique())

    moments = np.asarray(phi_S_eta_mu_all)
    if moments.ndim == 1:
        moments = moments[:, None]

    omega_t_est = np.zeros(len(data0))
    for k in range(1, time_num + 1):
        at_or_after_k = time_values >= k
        gram = np.dot(phi_S[at_or_after_k].T, phi_S[at_or_after_k])
        gamma_k = solve(gram, moments[:, k - 1]) * n_H / n_D
        at_k = time_values == k
        omega_t_est[at_k] = np.dot(phi_S[at_k], gamma_k)

    # Enforce non-negativity, then normalise to mean one within each time step.
    omega_t_est[omega_t_est < 0] = 0
    step_means = pd.Series(omega_t_est).groupby(time_values).transform('mean').to_numpy()
    return omega_t_est / step_means


def Q_eta_est_backward(data, treatment, prob_behavior, Q_indicator='Linear'):
    """Doubly-robust style value estimate for always applying ``treatment``.

    Steps:
      1. Obtain ``pi_t`` (probability of the observed action) from
         ``prob_behavior`` or, if that is ``None``, from ``behavior_est``.
      2. Set ``eta_t = 1 / pi_t`` on rows where ``action == treatment``, else 0.
      3. Fit a random forest of ``reward`` on the state features, using rows
         at the last time step that took ``treatment``; predict for all rows.
      4. Form the weighted residual ``eta`` and combine it with the mean
         prediction to get ``eta_est``.

    Fix: the per-user residual ``TD_user`` used to add a length-``n`` vector
    to an ``(n, 1)`` column, which broadcast to an ``(n, n)`` matrix and made
    ``variance`` ignore the covariance between the two terms. Both terms are
    now 1-D with one entry per user.

    Caveat: ``data`` is modified in place (columns ``pi_t``, ``eta_t``,
    ``mu_t``, ``fitted_q_hat``, ``fitted_q_next``, ``eta``, ``fitted_value``
    are added or overwritten) and is returned under ``'data'``. Pass a copy
    if the same frame is evaluated for several treatments.

    Args:
        data: DataFrame with ``user``, ``time``, ``action``, ``reward``,
            ``S*`` and ``next.S*`` columns.
        treatment: action value whose value is being estimated.
        prob_behavior: probability of ``action == 1`` under the behaviour
            policy, or ``None`` to estimate it.
        Q_indicator: ``'Linear'`` uses raw states as features; any other
            value uses ``phi_basis``.

    Returns:
        dict with keys ``eta_est``, ``value0``, ``var_value0``,
        ``TD_user_value0``, ``TD_user``, ``variance``, ``data``, ``model``
        and ``phi_S``.
    """
    if prob_behavior is None:
        data = behavior_est(data, treatment)
    else:
        data['pi_t'] = np.where(data['action'] == 1, prob_behavior, 1 - prob_behavior)

    # Inverse-probability weight, non-zero only where the target action was taken.
    took_treatment = data['action'] == treatment
    data['eta_t'] = np.where(took_treatment, 1 / data['pi_t'], 0)

    # Any column name containing 'S' is counted; assumes S*/next.S* pairs only.
    nS = len([col for col in data.columns if 'S' in col]) // 2
    time_num = len(data['time'].unique())

    S = data[[f"S{i+1}" for i in range(nS)]].values
    phi_S = S if Q_indicator == 'Linear' else phi_basis(S)

    # Outcome model: last-step rows that received `treatment`.
    position_k = (data['time'] == time_num).to_numpy()
    fit_rows = position_k & took_treatment.to_numpy()
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(phi_S[fit_rows], data.loc[fit_rows, 'reward'])

    fitted_value = model.predict(phi_S)
    fitted_q_hat = fitted_value.copy()
    fitted_q_next = np.zeros_like(fitted_value)  # no continuation value is modelled

    # Marginal weights: estimated for multi-step data, eta_t itself otherwise.
    if time_num > 1:
        mu_t_estimation = mu_t_est(data, treatment)
        data['mu_t'] = mu_t_estimation['mu_t_est']
    else:
        data['mu_t'] = data['eta_t']

    data['fitted_q_hat'] = fitted_q_hat
    data['fitted_q_next'] = fitted_q_next

    # Weighted temporal-difference residual, averaged over the horizon.
    data['eta'] = data['mu_t'] * (data['reward'] + data['fitted_q_next'] - data['fitted_q_hat'])
    data['eta'] = data['eta'] / time_num

    data['fitted_value'] = fitted_value

    # Plug-in value plus the mean per-user correction.
    value0 = np.mean(fitted_value[position_k]) / time_num
    eta_by_user = data.groupby('user')['eta'].sum()
    eta_est = value0 + eta_by_user.mean()

    # Per-user influence terms (1-D). Assumes one last-step row per user, in
    # the same order as the sorted user ids produced by groupby.
    TD_user = fitted_value[position_k] / time_num + eta_by_user.to_numpy() - eta_est
    variance = np.mean(TD_user ** 2)

    TD_user_value0 = fitted_value[position_k] / time_num - value0
    var_value0 = np.mean(TD_user_value0 ** 2)
    return {
        'eta_est': eta_est,
        'value0': value0,
        'var_value0': var_value0,
        'TD_user_value0': TD_user_value0,
        'TD_user': TD_user,
        'variance': variance,
        'data': data,
        'model': model,
        'phi_S': phi_S,
    }
def eta_est_historical(result_D, result_H, n_D, n_H, time_num, ratio_indicator='Given'):
    """Estimate the control value on the experimental data using the historical model.

    The estimate is the sum of two parts:
      * part 1: mean prediction of the historical outcome model
        (``result_H['model']``) on the experimental features
        (``result_D['phi_S']``), divided by ``time_num``;
      * part 2: mean over historical users of their summed ``eta`` residuals.

    Side effects: adds ``eta_est_for_D_part_1`` to ``result_D['data']`` and
    ``eta_est_for_D_part_2`` and ``rho_t`` to ``result_H['data']``. These
    assignments need the lengths to line up, which the visible code only
    guarantees for ``time_num == 1``.

    Changes from the previous version: an unused covariance term that needed
    ``result_D['TD_user']`` was removed, and a missing ``'phi_S_eta_mu_all'``
    entry now raises a ``KeyError`` that says what is required.

    Args:
        result_D: result dict for the experimental control arm (``data``, ``phi_S``).
        result_H: result dict for the historical data (``data``, ``model``).
        n_D, n_H: number of experimental / historical users.
        time_num: number of time steps.
        ratio_indicator: ``'Given'`` uses unit density-ratio weights; any
            other value estimates them with ``omega_t_est_backward``.

    Returns:
        dict with ``eta_est``, ``var_H_for_D``, ``TD_user_for_D_part_1``,
        ``df_D`` and ``df_H``.
    """
    data_D = result_D['data']
    data_H = result_H['data']

    # Part 1: historical outcome model evaluated on the experimental states.
    position_1_D = data_D['time'] == 1
    fitted_value_for_D = result_H['model'].predict(result_D['phi_S'])
    data_D['eta_est_for_D_part_1'] = np.zeros(len(data_D))
    data_D.loc[position_1_D, 'eta_est_for_D_part_1'] = fitted_value_for_D / time_num

    # Part 2: per-user sums of the historical residuals.
    eta_by_user_H = data_H.groupby('user')['eta'].sum()
    data_H['eta_est_for_D_part_2'] = np.array(eta_by_user_H)

    eta_est_for_D_part_1 = np.mean(fitted_value_for_D) / time_num
    eta_est_for_D_part_2 = np.mean(eta_by_user_H)

    if ratio_indicator == 'Given':
        omega_t_est = np.ones(n_H * time_num)
    else:
        if 'phi_S_eta_mu_all' not in result_D:
            raise KeyError(
                "result_D must contain 'phi_S_eta_mu_all' when "
                "ratio_indicator is not 'Given'"
            )
        omega_t_est = omega_t_est_backward(data_H, result_D['phi_S_eta_mu_all'], n_D, n_H)

    # Density-ratio weighted residual, stored for downstream use.
    data_H['rho_t'] = omega_t_est * (
        data_H['reward'] + data_H['fitted_q_next'] - data_H['fitted_q_hat']
    ) / time_num

    # Centred terms and the resulting variance of the combined estimate.
    TD_user_for_D_part_1 = fitted_value_for_D / time_num - eta_est_for_D_part_1
    TD_user_for_D_part_2 = eta_by_user_H - eta_est_for_D_part_2
    variance_for_D_part_1 = np.mean(TD_user_for_D_part_1 ** 2)
    variance_for_D_part_2 = np.mean(TD_user_for_D_part_2 ** 2)

    eta_H_for_D = eta_est_for_D_part_1 + eta_est_for_D_part_2
    var_H_for_D = variance_for_D_part_1 / n_D + variance_for_D_part_2 / n_H

    # Compact frames consumed by the weight-function losses.
    df_D = pd.DataFrame()
    df_H = pd.DataFrame()
    df_D['S1'] = data_D['S1'] * 1.0
    df_D['dr_est_t0'] = data_D['eta'] + data_D['fitted_value']
    df_D['eta_est_for_D_part_1'] = data_D['eta_est_for_D_part_1']
    df_H['S1'] = data_H['S1'] * 1.0
    df_H['eta_est_for_D_part_2'] = data_H['eta_est_for_D_part_2'] * 1.0
    return {
        'eta_est': eta_H_for_D,
        'var_H_for_D': var_H_for_D,
        'TD_user_for_D_part_1': TD_user_for_D_part_1,
        'df_D': df_D,
        'df_H': df_H,
    }

# pessimist_combine function
def pessimist_combine(result_D_1, result_D_0, result_H, n_D, n_H, time_num, ratio_indicator='Given'):
    """Combine the experiment-only and history-assisted ATE estimators.

    ``ATE_e = eta_D_1 - eta_D_0`` uses experimental data only;
    ``ATE_h = eta_D_1 - eta_H_for_D`` replaces the control value with the
    historical estimate from ``eta_est_historical``. Three weighted
    combinations of the two are also returned.

    Fix: the previous version ended with a "testing-based" selector
    (use ``ATE_e`` when ``|t| > z_0.975``, else ``ATE_pessi_DR``) that called
    ``norm.ppf`` although ``norm`` is never imported, so every call raised
    ``NameError``. Its result was never returned, so that dead branch has
    been removed and the returned array is unchanged.

    Args:
        result_D_1, result_D_0: result dicts for the treated / control arms
            (``eta_est``, ``TD_user``, ``variance``).
        result_H: result dict for the historical data.
        n_D, n_H: number of experimental / historical users.
        time_num: number of time steps.
        ratio_indicator: forwarded to ``eta_est_historical``.

    Returns:
        ``np.array([ATE_e, ATE_h, ATE_pessi_DR, ATE_mse, ATE_L1])``.
    """
    eta_D_1 = result_D_1['eta_est']
    eta_D_0 = result_D_0['eta_est']
    eta_est_hist = eta_est_historical(result_D_0, result_H, n_D, n_H, time_num, ratio_indicator)

    eta_H_for_D = eta_est_hist['eta_est']

    ATE_e = eta_D_1 - eta_D_0
    ATE_h = eta_D_1 - eta_H_for_D

    TD_1 = result_D_1['TD_user']
    TD_0 = result_D_0['TD_user']
    TD_H = eta_est_hist['TD_user_for_D_part_1']

    cov_D_1_0 = np.mean(TD_1 * TD_0)
    cov_D_1_H = np.mean(TD_1 * TD_H)

    # Variances of ATE_e and ATE_h and their covariance.
    var_D = (result_D_1['variance'] + result_D_0['variance'] - 2 * cov_D_1_0) / n_D
    var_H = (result_D_1['variance'] - 2 * cov_D_1_H) / n_D + eta_est_hist['var_H_for_D']
    cov_D_H = np.mean((TD_1 - TD_0) * (TD_1 - TD_H)) / n_D

    # Pessimistic weight: replace the squared bias by an upper bound
    # (|bias| + 1.64 * se)^2.
    bias_square_UB_for_D = (abs(eta_D_0 - eta_H_for_D) + 1.64 * np.sqrt(var_D + var_H - 2 * cov_D_H)) ** 2
    weight_for_D = (var_H + bias_square_UB_for_D - cov_D_H) / (var_D + var_H + bias_square_UB_for_D - 2 * cov_D_H)
    ATE_pessi_DR = weight_for_D * ATE_e + (1 - weight_for_D) * ATE_h

    # MSE-minimising weight between the two estimators, using the plug-in
    # squared bias.
    bias_square = (eta_D_0 - eta_H_for_D) ** 2
    weight_1 = (var_H + bias_square - cov_D_H) / (var_D + var_H + bias_square - 2 * cov_D_H)
    ATE_mse = weight_1 * ATE_e + (1 - weight_1) * ATE_h

    # L1-penalised weight on the historical estimator, clipped to [0, 1].
    weight_2 = (var_D - 0.4 * bias_square) / (var_D + var_H)
    weight_2 = np.clip(weight_2, 0, 1)
    ATE_L1 = (1 - weight_2) * ATE_e + weight_2 * ATE_h

    return np.array([ATE_e, ATE_h, ATE_pessi_DR, ATE_mse, ATE_L1])

def get_oracle_w(result_D_1, result_D_0, result_H, n_D, n_H, time_num, ratio_indicator='Given', est_n_D=10):
    """Return the MSE-minimising weight on the experiment-only estimator.

    The variance and covariance terms are formed from the per-user terms in
    the result dicts and rescaled by ``n_D / est_n_D``; the squared bias is
    the squared gap between the control value and the historical estimate
    returned by ``eta_est_historical``.
    """
    # These lookups are not used in the weight itself but are kept so that a
    # missing 'eta_est' key fails exactly as before.
    _eta_treated = result_D_1['eta_est']
    eta_control = result_D_0['eta_est']
    _eta_hist_raw = result_H['eta_est']

    hist = eta_est_historical(result_D_0, result_H, n_D, n_H, time_num, ratio_indicator)
    eta_hist = hist['eta_est']

    td_treated = result_D_1['TD_user']
    td_control = result_D_0['TD_user']
    td_hist = hist['TD_user_for_D_part_1']

    cov_treated_control = np.mean(td_treated * td_control)
    cov_treated_hist = np.mean(td_treated * td_hist)

    rescale = (n_D) / est_n_D
    var_D = rescale * (
        (result_D_1['variance'] + result_D_0['variance'] - 2 * cov_treated_control) / est_n_D
    )
    var_H = rescale * (
        (result_D_1['variance'] - 2 * cov_treated_hist) / n_D + hist['var_H_for_D']
    )
    cov_D_H = rescale * (
        np.mean((td_treated - td_control) * (td_treated - td_hist)) / n_D
    )

    # Plug-in squared bias of the historical control estimate.
    bias_square = (eta_control - eta_hist) ** 2
    numerator = var_H + bias_square - cov_D_H
    denominator = var_D + var_H + bias_square - 2 * cov_D_H
    return numerator / denominator


# Combine function in Python
def combine(result_D_1, result_D_0, result_H, n_D, n_H, time_num, ratio_indicator='Given',
           lr=0.1, epochs=5000, weight_oracle=None, hidden_size=10):
    """Compute several ATE estimators that mix experimental and historical data.

    Trains state-dependent weight functions (logistic and neural-network) by
    minimising ``get_loss`` / ``get_loss_ub`` and also evaluates closed-form
    constant-weight combinations.

    Returns:
        ``np.array`` ordered as ``[ATE_e, ATE_h, ATE_pessi_DR,
        ATE_func_w_logit, ATE_func_w_ub_logit, <one lasso ATE per lambda>,
        ATE_MVE]``.

    NOTE(review): ``weight_oracle`` and ``eta_H`` are not used; the two
    neural-network estimates and ``ATE_mse`` are computed but not returned.
    """
    eta_D_1 = result_D_1['eta_est']
    eta_D_0 = result_D_0['eta_est']
    eta_H = result_H['eta_est']  # not used below

    # Historical estimate of the control value plus the per-row frames used
    # by the loss functions.
    eta_est_hist = eta_est_historical(result_D_0, result_H, n_D, n_H, time_num, ratio_indicator)

    df_D = eta_est_hist['df_D']
    df_H = eta_est_hist['df_H']
    # Per-row doubly-robust term for the treated arm.
    df_D['dr_est_t1'] = result_D_1['data']['eta'] + result_D_1['data']['fitted_value']


    # Logistic weight function, fitted with the plain and the upper-bound loss.
    used_model_logit = LogitTransform()
    ATE_func_w_logit = optimize_parameters(used_model_logit, loss_func=get_loss,
                                     df_D=df_D, df_H=df_H, lr=lr, epochs=epochs)
    used_model_ub_logit = LogitTransform()
    ATE_func_w_ub_logit = optimize_parameters(used_model_ub_logit, loss_func=get_loss_ub, df_D=df_D, df_H=df_H, lr=lr, epochs=epochs)


    # Neural-network weight function, fitted with both losses.
    # NOTE(review): these two results are never used or returned.
    used_model_nn = NeuralNetwork(hidden_size=hidden_size)
    ATE_func_w_nn = optimize_parameters(used_model_nn, loss_func=get_loss,
                                     df_D=df_D, df_H=df_H, lr=lr, epochs=epochs)
    used_model_ub_nn = NeuralNetwork(hidden_size=hidden_size)
    ATE_func_w_ub_nn = optimize_parameters(used_model_ub_nn, loss_func=get_loss_ub, df_D=df_D, df_H=df_H, lr=lr, epochs=epochs)


    ## Ordinary method
    eta_H_for_D = eta_est_hist['eta_est']

    # Experiment-only and history-assisted ATE.
    ATE_e = eta_D_1 - eta_D_0
    ATE_h = eta_D_1 - eta_H_for_D

    cov_D_1_0 = np.mean(result_D_1['TD_user'] * result_D_0['TD_user'])
    cov_D_1_H = np.mean(result_D_1['TD_user'] * eta_est_hist['TD_user_for_D_part_1'])
    # Variances of ATE_e / ATE_h and their covariance.
    var_D = (result_D_1['variance'] + result_D_0['variance'] - 2 * cov_D_1_0) / n_D
    var_H = (result_D_1['variance'] - 2 * cov_D_1_H) / n_D + eta_est_hist['var_H_for_D']

    cov_D_H = np.mean((result_D_1['TD_user'] - result_D_0['TD_user']) * (result_D_1['TD_user'] - eta_est_hist['TD_user_for_D_part_1'])) / n_D

    # Variance-only weight (no bias term); 1e-6 guards against a zero denominator.
    weight_for_SPED=(var_H-cov_D_H)/(var_D + var_H - 2 * cov_D_H+0.000001)
    ATE_MVE= weight_for_SPED*ATE_e+(1 -weight_for_SPED) * ATE_h
    bias_squared=(eta_D_0 - eta_H_for_D)**2
    lambda_values = [0.2]
    lasso_ATE_results = {}
    for lambda_reg in lambda_values:
        # Here w is the weight on ATE_h: variance of the mix plus an
        # L1 penalty scaled by the squared bias.
        def objective(w):
            return (w**2 * var_H + (1 - w)**2 * var_D + 
                    2 * w * (1 - w) * cov_D_H + lambda_reg * abs(w) * bias_squared)

        w0 = 0.5
        result = minimize(objective, w0, method='L-BFGS-B')
        optimal_w = result.x[0]
        ATE_lasso =(1-optimal_w) * ATE_e + (optimal_w) * ATE_h
        lasso_ATE_results[lambda_reg] = ATE_lasso
    # Pessimistic weight: squared bias replaced by (|bias| + 1.64 * se)^2.
    bias_square_UB_for_D = (abs(eta_D_0 - eta_H_for_D) + 1.64 * np.sqrt(var_D + var_H - 2 * cov_D_H))**2

    weight_for_D = (var_H + bias_square_UB_for_D - cov_D_H) / (var_D + var_H + bias_square_UB_for_D - 2 * cov_D_H)

    ATE_pessi_DR = weight_for_D * ATE_e + (1 - weight_for_D) * ATE_h
    # MSE-minimising weight with the plug-in squared bias.
    # NOTE(review): ATE_mse is not included in the returned array.
    bias_square = (eta_D_0 - eta_H_for_D) ** 2
    weight_1 = (var_H + bias_square - cov_D_H) / (var_D + var_H + bias_square - 2 * cov_D_H)
    ATE_mse = weight_1 * ATE_e + (1 - weight_1) * ATE_h
    ATEs = np.array([
    ATE_e, ATE_h, ATE_pessi_DR,  # constant-weight estimators
    ATE_func_w_logit, ATE_func_w_ub_logit,  # logistic weight-function estimators
    *list(lasso_ATE_results.values()),  # one entry per lambda in lambda_values
    ATE_MVE  # variance-only weight
])

    return ATEs


# Historical value estimation function in Python
def eta_est_historical_value(result_D, result_H, n_D, n_H, time_num, ratio_indicator='Given'):
    """Plug-in value of the historical linear fit on the experimental data.

    Predictions are ``result_D['variable'] @ result_H['last_coe']`` restricted
    to rows with ``time == 1``. Their mean divided by ``time_num`` is the
    estimate, and the variance of the centred terms divided by ``n_D`` is its
    variance. ``n_H`` and ``ratio_indicator`` are accepted but not used.
    """
    first_step = result_D['data']['time'] == 1
    predictions = np.dot(result_D['variable'], result_H['last_coe'])
    fitted_value_for_D = predictions[first_step]

    scaled = fitted_value_for_D / time_num
    value_estimate = np.mean(fitted_value_for_D) / time_num
    centred = scaled - value_estimate

    summary = {}
    summary['eta_est'] = value_estimate
    summary['var_H_for_D'] = np.mean(centred ** 2) / n_D
    summary['TD_user_for_D_part_1'] = centred
    summary['fitted_value_for_D'] = fitted_value_for_D
    return summary

def rwrd_gen(S, a, S_next, error_tmp, t, d, mu_diff, b, delta, typ):
    """Generate a reward for state ``S``, action ``a`` and noise ``error_tmp``.

    ``typ`` selects the reward model:
      * 1: piecewise shift and noise scale depending on the region of ``S``;
      * 2: linear in ``S``;
      * 3: ``cos(S)``;
      * 4: ``S ** 2``;
      * 5: ``|S|``.

    ``S_next`` and ``t`` are accepted but not used.

    Fix: an unsupported ``typ`` used to fall through every branch and fail
    with ``UnboundLocalError`` at ``return R``; it now raises ``ValueError``.

    Raises:
        ValueError: if ``typ`` is not one of 1, 2, 3, 4, 5.
    """
    if typ == 1:
        conditions = [S < -1, (S >= -1) & (S < 0), S > 0]
        # NOTE(review): S == 0 matches none of the conditions, so np.select
        # returns 0 for both multipliers there -- confirm whether the last
        # condition was meant to be S >= 0.
        d_S = d * np.select(conditions, [-1, -1, 2])   # noise-scale multiplier
        d_mu = d * np.select(conditions, [0, 1, 1])    # mean-shift multiplier
        return 10 + d_mu * mu_diff + b * a + b * S + (2 + 1 * delta * d_S) * error_tmp

    state_transforms = {
        2: lambda s: s,
        3: np.cos,
        4: lambda s: s ** 2,
        5: abs,
    }
    if typ not in state_transforms:
        raise ValueError(f"unsupported reward type typ={typ!r}; expected one of 1, 2, 3, 4, 5")

    g_S = state_transforms[typ](S)
    return 10 + mu_diff + b * a + b * g_S + (2 + 1 * delta) * error_tmp + mu_diff * g_S

# Next state generation function in Python
def next_S_gen(S, a, t, d, p_s):
    """Draw the next state from an AR(1)-type transition.

    The mean is ``0.25 * S + 0.2 * a`` and the standard-normal innovation is
    scaled by ``2 + d * p_s``. ``t`` is accepted but not used.
    """
    mean_part = 0.25 * S + 0.2 * a
    noise_scale = 2 + d * p_s
    innovation = np.random.randn()
    return mean_part + noise_scale * innovation

# Data generation function in Python
def data_gen(n, nT, prob, d, b=0, mu_diff=0, TI=1, typ=0, heavy_tail=False,df=1):
    """Simulate ``n`` trajectories of length ``nT``.

    Fix: the correlated reward errors were reshaped to a single column and
    then indexed as ``errors[i, t-1]``, which raised ``IndexError`` whenever
    ``nT > 1``. They are now kept as an ``(n, nT)`` matrix, one row per user
    and one column per time step. Results for ``nT == 1`` are unchanged.

    Args:
        n: number of users.
        nT: number of time steps per user.
        prob: treatment assignment rule. ``None`` alternates by user and time
            parity, ``3`` uses a switchback schedule, ``4`` treats the first
            half of users, and any other value is the Bernoulli probability
            of ``action == 1``.
        d: design parameter forwarded to ``next_S_gen`` and ``rwrd_gen``.
        b, mu_diff: reward parameters forwarded to ``rwrd_gen``.
        TI: switchback block length, used only when ``prob == 3``.
        typ: type of reward function, forwarded to ``rwrd_gen``. That
            function only defines rewards for 1-5, so the default 0 fails
            there; pass ``typ`` explicitly.
        heavy_tail: draw errors from a multivariate t instead of a normal.
        df: degrees of freedom of the multivariate t.

    Returns:
        DataFrame with columns ``user``, ``time``, ``S1``, ``action``,
        ``prob``, ``reward`` and ``next.S1``.
    """
    data = []
    # AR(1)-style error correlation across time: 0.5 ** |s - t|.
    time_idx = np.arange(1, nT+1)
    error_sigma = 0.5 ** np.abs(np.subtract.outer(time_idx, time_idx))

    if heavy_tail:
        sampler = multivariate_t(df=df, loc=np.zeros(nT), shape=error_sigma)
    else:
        sampler = multivariate_normal(mean=np.zeros(nT), cov=error_sigma)
    # One row of errors per user, one column per time step.
    errors = np.asarray(sampler.rvs(size=n)).reshape(n, nT)

    if TI == nT:
        A1 = np.ones(TI)
        A2 = np.zeros(TI)
    else:
        # NOTE(review): these schedules have 2 * (TI // 2) * (nT // TI // 2)
        # entries, which can be shorter than nT -- confirm the block length.
        A1 = np.tile(np.repeat([0, 1], TI//2), nT//TI//2)
        A2 = np.tile(np.repeat([1, 0], TI//2), nT//TI//2)

    for i in range(n):
        next_S = np.random.randn()

        for t in range(1, nT+1):
            S = next_S
            if prob is None:
                A = 1 if (i % 2 == 0 and t % 2 == 0) or (i % 2 != 0 and (t+1) % 2 == 0) else 0
            elif prob == 3:
                A = A1[t-1] if i % 2 == 0 else A2[t-1]
            elif prob == 4:
                A = 1 if i <= n / 2 else 0
            else:
                A = np.random.binomial(1, prob)

            next_S = next_S_gen(S, A, t, d, p_s=0)
            R = rwrd_gen(S, A, next_S, errors[i, t-1], t, d, mu_diff=mu_diff, b=b, delta=1, typ=typ)

            data.append([i+1, t, S, A, prob, R, next_S])

    columns = ["user", "time", "S1", "action", "prob", "reward", "next.S1"]
    return pd.DataFrame(data, columns=columns)

# Define logit transformation function w(S)
class LogitTransform(nn.Module):
    """Logistic weight function ``w(S) = sigmoid(theta * S + b)``.

    Both parameters start at zero, so the initial weight is 0.5 everywhere.

    The forward pass uses ``torch.sigmoid`` rather than the hand-written
    ``1 / (1 + exp(-z))``. The two agree numerically, but the hand-written
    form overflows ``exp`` for very negative ``z`` and then back-propagates
    ``0 * inf = NaN``, which poisons the parameters during optimisation.
    """

    def __init__(self):
        super(LogitTransform, self).__init__()
        # Slope and intercept of the logit.
        self.theta = nn.Parameter(torch.tensor([0.0]))
        self.b = nn.Parameter(torch.tensor([0.0]))

    def forward(self, S):
        """Return ``sigmoid(theta * S + b)``, element-wise over ``S``."""
        return torch.sigmoid(self.theta * S + self.b)
    
# Define a neural network model with one hidden layer
class NeuralNetwork(nn.Module):
    """One-hidden-layer network mapping a state to a weight in (0, 1).

    Architecture: ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, 1) -> Sigmoid``.
    """

    def __init__(self, input_size=1, hidden_size=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.output_activation = nn.Sigmoid()

        # Replace the default initialisation with the scheme below.
        self._initialize_weights()

    def forward(self, S):
        """Run the network on a batch of inputs ``S``."""
        hidden = self.activation(self.fc1(S))
        return self.output_activation(self.fc2(hidden))

    def _initialize_weights(self):
        """He-uniform for the hidden layer, Xavier-uniform for the output, zero biases."""
        nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc2.weight)
        for layer in (self.fc1, self.fc2):
            nn.init.zeros_(layer.bias)


def get_loss(df_D, df_H, used_model):
    """Squared-bias-plus-variance loss for a state-dependent weight ``w(S1)``.

    ``df_D`` supplies ``S1``, ``dr_est_t0``, ``dr_est_t1`` and
    ``eta_est_for_D_part_1``; ``df_H`` supplies ``S1`` and
    ``eta_est_for_D_part_2``. ``used_model`` maps an ``(n, 1)`` float tensor
    of states to weights.

    Returns:
        Scalar tensor: ``bias ** 2 + var_D / n_d + var_H / n_h``.
    """
    def as_column(frame, name):
        # Column -> float32 tensor of shape (rows, 1).
        return torch.tensor(frame[name].values, dtype=torch.float32).unsqueeze(1)

    n_d, n_h = df_D.shape[0], df_H.shape[0]

    states_D = as_column(df_D, 'S1')
    states_H = as_column(df_H, 'S1')
    dr_t0 = as_column(df_D, 'dr_est_t0')
    dr_t1 = as_column(df_D, 'dr_est_t1')
    part_1 = as_column(df_D, 'eta_est_for_D_part_1')
    part_2 = as_column(df_H, 'eta_est_for_D_part_2')

    # Weights evaluated on the experimental and historical states.
    w_D = used_model(states_D)
    w_H = used_model(states_H)

    # Weighted gap between the experimental control term and the historical
    # plug-in (experimental side), and the weighted historical correction.
    exp_term = w_D * dr_t0 - w_D * part_1
    hist_term = w_H * part_2

    bias = torch.mean(exp_term) - torch.mean(hist_term)

    # Per-row contribution of the experimental data to the estimator.
    per_row_D = (dr_t1 - dr_t0) + exp_term

    variance_total = torch.var(per_row_D) / n_d + torch.var(hist_term) / n_h
    return bias ** 2 + variance_total
    

def get_loss_ub(df_D, df_H, used_model):
    """Upper-bound version of the weight-function loss.

    Uses the same inputs as ``get_loss``. The squared bias and both variance
    terms are each inflated by 1.64 times a standard-error-type term before
    being summed.

    Returns:
        Scalar tensor with the inflated squared bias plus inflated variances.
    """
    def as_column(frame, name):
        # Column -> float32 tensor of shape (rows, 1).
        return torch.tensor(frame[name].values, dtype=torch.float32).unsqueeze(1)

    n_d, n_h = df_D.shape[0], df_H.shape[0]

    states_D = as_column(df_D, 'S1')
    states_H = as_column(df_H, 'S1')
    dr_t0 = as_column(df_D, 'dr_est_t0')
    dr_t1 = as_column(df_D, 'dr_est_t1')
    part_1 = as_column(df_D, 'eta_est_for_D_part_1')
    part_2 = as_column(df_H, 'eta_est_for_D_part_2')

    w_D = used_model(states_D)
    w_H = used_model(states_H)

    exp_term = w_D * (dr_t0 - part_1)
    hist_term = w_H * part_2

    # Squared bias, inflated by the variances of the two sample means.
    mean_gap = torch.mean(exp_term) - torch.mean(hist_term)
    gap_var_D = torch.var(exp_term) / n_d
    gap_var_H = torch.var(hist_term) / n_h
    bias_sq_upper = torch.mean(mean_gap) ** 2 + 1.64 * (gap_var_D + gap_var_H)

    # Per-row contribution of the experimental data to the estimator.
    per_row_D = (dr_t1 - dr_t0) + exp_term

    # Variances, inflated by the spread of the squared terms.
    spread_D = torch.std(per_row_D ** 2) / np.sqrt(n_d)
    spread_H = torch.std(hist_term ** 2) / np.sqrt(n_h)
    var_upper_D = torch.var(per_row_D) + 1.64 * spread_D
    var_upper_H = torch.var(hist_term) + 1.64 * spread_H

    return bias_sq_upper + var_upper_D / n_d + var_upper_H / n_h

def calculate_ATE(df_D, df_H, used_model):
    """Evaluate the ATE implied by a fitted weight function ``w(S1)``.

    ``ATE = mean(dr_t1) - mean((1 - w_D) * dr_t0) - mean(w_D * part_1)
    - mean(w_H * part_2)``, where ``w_D`` / ``w_H`` are the weights on the
    experimental / historical states.

    Fix: columns are now converted through ``.values``, as ``get_loss`` and
    ``get_loss_ub`` already do. Passing a pandas Series directly to
    ``torch.tensor`` relies on label-based element access, which breaks (or
    silently reorders) when the frame does not have a default 0..n-1 index.

    Args:
        df_D: frame with ``S1``, ``dr_est_t1``, ``dr_est_t0`` and
            ``eta_est_for_D_part_1``.
        df_H: frame with ``S1`` and ``eta_est_for_D_part_2``.
        used_model: callable mapping an ``(n, 1)`` float tensor to weights.

    Returns:
        Scalar tensor with the ATE estimate.
    """
    def as_column(frame, name):
        # Column -> float32 tensor of shape (rows, 1), independent of the index.
        return torch.tensor(frame[name].values, dtype=torch.float32).unsqueeze(1)

    S1_D = as_column(df_D, 'S1')
    dr_est_t1 = as_column(df_D, 'dr_est_t1')
    dr_est_t0 = as_column(df_D, 'dr_est_t0')
    eta_est_for_D_part_1 = as_column(df_D, 'eta_est_for_D_part_1')
    S1_H = as_column(df_H, 'S1')
    eta_est_for_D_part_2 = as_column(df_H, 'eta_est_for_D_part_2')

    w_S1_D = used_model(S1_D)
    w_S1_H = used_model(S1_H)

    treated_value = dr_est_t1.mean()
    control_from_experiment = ((1 - w_S1_D) * dr_est_t0).mean()
    control_from_model = (w_S1_D * eta_est_for_D_part_1).mean()
    control_from_history = (w_S1_H * eta_est_for_D_part_2).mean()

    return treated_value - control_from_experiment - control_from_model - control_from_history



def optimize_parameters(used_model, loss_func, df_D, df_H, lr=0.1, epochs=100, tol=1e-6):
    """Fit ``used_model`` with Adam and return the resulting ATE estimate.

    Training stops after ``epochs`` steps, or earlier once the loss changes
    by less than ``tol`` between consecutive steps.

    Fix: the "previous loss" sentinel was the finite value 999, so a first
    loss within ``tol`` of 999 would have stopped training after one step.
    It now starts at infinity, so the first step can never trigger the
    stopping rule.

    Args:
        used_model: ``torch.nn.Module`` mapping states to weights.
        loss_func: callable ``(df_D, df_H, used_model) -> scalar tensor``.
        df_D, df_H: frames passed through to ``loss_func`` and ``calculate_ATE``.
        lr: Adam learning rate.
        epochs: maximum number of optimisation steps.
        tol: absolute loss-change threshold for early stopping.

    Returns:
        float: ``calculate_ATE(df_D, df_H, used_model)`` after training.
    """
    optimizer = torch.optim.Adam(used_model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)

    prev_loss = float("inf")
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_func(df_D, df_H, used_model)
        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        if abs(prev_loss - current_loss) < tol:
            break
        prev_loss = current_loss

    # Evaluation only; no gradients are needed for the final estimate.
    with torch.no_grad():
        ATE = calculate_ATE(df_D, df_H, used_model)
    return ATE.item()